{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# yolox pins onnxruntime==1.8.0, which is outdated, hence --no-deps.\n",
    "# --no-deps also skips yolox's other requirements, so they must already be installed.\n",
    "# %pip (rather than bare pip) installs into the environment of the running kernel.\n",
    "%pip install yolox --no-deps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import gdown\n",
    "import torch\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from torchvision import transforms\n",
    "from ultralytics.utils import ops\n",
    "\n",
    "from yolox.exp import get_exp\n",
    "from yolox.utils import postprocess\n",
    "from yolox.utils.model_utils import fuse_model\n",
    "from boxmot import BotSort\n",
    "from boxmot.utils.ops import yolox_preprocess\n",
    "\n",
    "\n",
    "# Dictionary for YOLOX model weights URLs\n",
    "YOLOX_ZOO = {\n",
    "    'yolox_n.pt': 'https://drive.google.com/uc?id=1AoN2AxzVwOLM0gJ15bcwqZUpFjlDV1dX',\n",
    "    'yolox_s.pt': 'https://drive.google.com/uc?id=1uSmhXzyV1Zvb4TJJCzpsZOIcw7CCJLxj',\n",
    "    'yolox_m.pt': 'https://drive.google.com/uc?id=11Zb0NN_Uu7JwUd9e6Nk8o2_EUfxWqsun',\n",
    "    'yolox_l.pt': 'https://drive.google.com/uc?id=1XwfUuCBF4IgWBWK2H7oOhQgEj9Mrb3rz',\n",
    "    'yolox_x.pt': 'https://drive.google.com/uc?id=1P4mY0Yyd3PPTybgZkjMYhFri88nTmJX5',\n",
    "}\n",
    "\n",
    "# Preprocessing pipeline\n",
    "input_size = [800, 1440]  # network input size for yolox_preprocess (presumably height, width)\n",
    "device = torch.device('cpu')\n",
    "yolox_model = 'yolox_s.pt'  # any key of YOLOX_ZOO\n",
    "yolox_model_path = Path(yolox_model)\n",
    "\n",
    "# Download model if not present\n",
    "if not yolox_model_path.exists():\n",
    "    gdown.download(YOLOX_ZOO[yolox_model], output=str(yolox_model_path), quiet=False)\n",
    "\n",
    "# Initialize YOLOX model. The experiment (architecture) must match the checkpoint,\n",
    "# so derive it from the weights file name instead of hard-coding 'yolox_s'.\n",
    "# yolox registers the nano experiment as 'yolox_nano', not 'yolox_n'.\n",
    "exp_name = yolox_model_path.stem\n",
    "if exp_name == 'yolox_n':\n",
    "    exp_name = 'yolox_nano'\n",
    "exp = get_exp(None, exp_name)\n",
    "exp.num_classes = 1  # single-class detection head\n",
    "# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.\n",
    "ckpt = torch.load(yolox_model_path, map_location=device)\n",
    "\n",
    "model = exp.get_model()\n",
    "model.load_state_dict(ckpt[\"model\"])\n",
    "model = fuse_model(model).to(device).eval()\n",
    "\n",
    "# Initialize tracker\n",
    "tracker = BotSort(reid_weights=Path('osnet_x0_25_msmt17.pt'), device=device, half=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Video capture setup\n",
    "CAMERA_INDEX = 0       # cv2.VideoCapture source (0 = default camera)\n",
    "CONF_THRESHOLD = 0.5   # confidence threshold passed to yolox postprocess\n",
    "NMS_THRESHOLD = 0.7    # NMS IoU threshold passed to yolox postprocess\n",
    "\n",
    "vid = cv2.VideoCapture(CAMERA_INDEX)\n",
    "\n",
    "# try/finally: release the camera and close windows even when the loop is\n",
    "# stopped by an exception or a kernel interrupt, so the device is not left locked.\n",
    "try:\n",
    "    if not vid.isOpened():\n",
    "        raise RuntimeError(f'Could not open video source {CAMERA_INDEX}')\n",
    "\n",
    "    while True:\n",
    "        ret, frame = vid.read()\n",
    "        if not ret:\n",
    "            break\n",
    "\n",
    "        # Preprocess frame; ratio is the scale factor returned by yolox_preprocess\n",
    "        frame_img, ratio = yolox_preprocess(frame, input_size=input_size)\n",
    "        frame_tensor = torch.Tensor(frame_img).unsqueeze(0).to(device)\n",
    "\n",
    "        # Detection with YOLOX (single class, class-agnostic NMS)\n",
    "        with torch.no_grad():\n",
    "            outputs = model(frame_tensor)\n",
    "        dets = postprocess(outputs, 1, CONF_THRESHOLD, NMS_THRESHOLD, class_agnostic=True)[0]\n",
    "\n",
    "        if dets is not None:\n",
    "            # Assumed row layout (yolox postprocess):\n",
    "            # x1, y1, x2, y2, obj_conf, class_conf, class_pred\n",
    "            # Rescale boxes from the letterboxed input back to the original frame size\n",
    "            dets[:, :4] /= ratio\n",
    "            # Fold class confidence into the score column\n",
    "            dets[:, 4] *= dets[:, 5]\n",
    "            # Keep (x1, y1, x2, y2, score, class) for the tracker\n",
    "            dets = dets[:, [0, 1, 2, 3, 4, 6]].cpu().numpy()\n",
    "        else:\n",
    "            dets = np.empty((0, 6))\n",
    "\n",
    "        # Update tracker\n",
    "        res = tracker.update(dets, frame)\n",
    "\n",
    "        # Plot results and display\n",
    "        tracker.plot_results(frame, show_trajectories=True)\n",
    "        cv2.imshow('BoXMOT + YOLOX', frame)\n",
    "\n",
    "        # Press 'q' in the display window to stop\n",
    "        if cv2.waitKey(1) & 0xFF == ord('q'):\n",
    "            break\n",
    "finally:\n",
    "    # Release resources\n",
    "    vid.release()\n",
    "    cv2.destroyAllWindows()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "boxmot-YDNZdsaB-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
